-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[offload] Add support for fp16 training #374
Conversation
Previous versions don't support amp I think. Fixing failures |
logging.info(f"Memory table {prof.key_averages().table()}") | ||
logging.info("Memory stats are " + str(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30)) | ||
logging.info( | ||
"Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dummy thought earlier in the day: it would be great to add some ballpark computation of the expected size at some point (given the batch size + model/shards), just for comparison
@@ -148,6 +150,7 @@ def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance | |||
return inputs if isinstance(inputs, tuple) else (inputs,) | |||
|
|||
@staticmethod | |||
@custom_bwd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you elaborate on that ? it's new to me
edit: I looked it up, sorry for the noise, makes sense
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, LGTM, thanks Anjali !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice. is there doc changes needed or updating the change log file?
I am committing to a feature branch which I will merge to master(very soon I think). I'll make sure to add doc changes and change log additions as needed. |
…s on 1 GPU. (#432) * clean start * removing per layer split strategy, probably not that useful indeed * initial transformer benchmark * hack, enable testing ViT + offload, python3 benchmarks/oss.py --epochs 2 --optim_type oss_offload_ddp --batch_size=32 --model vit_large_patch16_224 * proper cuda streams and device, something off in terms of mems consumption * minor, stashing * unit test fix * removing all the distributed parts * simpler test, needs debugging * working OOP, running a model which does not fit on the gpu memory * spring cleaning * removing the ill-advised optimizer bits, better keep that orthogonal * [offload] Add support for activation offloading + other changes (#367) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * avoid saving inputs * fix lint errors Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair> * [offload] Add support for fp16 training (#374) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * add support for fp16 * add unit tests * fix lint errors * fix test failure Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair> * [offload] Add support for activation checkpointing for all layers. (#381) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * add support for fp16 * add unit tests * fix lint errors * fix test failure * cp work, incorrect output dimensions still need to be fixed * fixed activation outputs * intermediate cp of work * add tests * fix lint errors Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair> * add support for microbatches * revert benchmark config changes * add parametrization * fix lint errors and tests * skip test for 1.5 * fix lint errors * skip test if there are no GPUs * fix lint errors * fix lint errors * move experimental to the fairscale repo * lint error fixes * modify test imports * lint error fixes * move offload files to the experimental directory * move tests and benchmarks to their forlder * fix mypy errors * cp intermediate working benchmarks * more changes * split benchmark configs * remove print statements * fix lint errors * remove unused print * stress testing * remove unused file * change param nae * lint fixes * move file to the right folder * offload_experimental * add doc string * add error message Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@gmail.com> Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@protonmail.com> Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
Before submitting
What does this PR do?
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃